import sys
import numpy as np
from Hamiltonian_solver_2b_sweep import *
from multiprocessing import Pool

def run_SCF(U, consts):
    """Run a single self-consistent-field calculation at interaction strength U.

    Used as the task function submitted to the multiprocessing Pool in
    ``main``: before building the ``SCF`` object it calls
    ``Hamiltonian_solver_sweep`` (always with U=0) so the solver's global
    variables are initialised inside the current process.

    Args:
        U: interaction strength handed to ``SCF``.
        consts: run-configuration dict; the keys "SWM", "kmax", "dk",
            "N_lay", "seed_val" and "H_renorm" are read here, and the whole
            dict is passed on to ``SCF``.

    Returns:
        Whatever ``SCF(U, consts).Mean_field_State()`` returns.
    """
    # Global-state initialisation is U-independent (U=0); the requested U is
    # only given to SCF below.
    Hamiltonian_solver_sweep(
        consts["SWM"],
        consts["kmax"],
        consts["dk"],
        consts["N_lay"],
        U=0,
        seed_val=consts["seed_val"],
        H_renorm=consts["H_renorm"],
    )
    solver = SCF(U, consts)
    return solver.Mean_field_State()

def main():
    """Read parameter overrides, build ``consts`` and sweep SCF over U.

    Overrides are passed as one optional command-line argument containing
    Python assignments, for example ``"N_U = 10; U_to = 40; par_num = 4"``.
    Each recognised name assigned there replaces the default read below;
    unassigned names keep their defaults.

    The U values are ``np.linspace(U_fr, U_to, N_U + 1)``.  When ``N_U == 0``
    or ``par_num == 0`` the sweep runs serially in this process (an exception
    aborts the sweep).  Otherwise each U value is submitted to a
    ``multiprocessing.Pool`` of ``par_num`` workers; a failing U value is
    reported on stderr and the remaining runs still complete.
    """
    # SWM Parameters in meV
    SWM = {
        "gamma_0": 3160,
        "gamma_1": 460,
        "gamma_2": -17,
        "gamma_3": -300,
        "gamma_4": -86,
        "delta":   -1.1,
    }

    # Parameters, Constants
    # The optional first argument is Python source; its top-level assignments
    # end up in ``local_scope`` and override the defaults below.
    # NOTE(security): exec() executes arbitrary code -- only pass trusted input.
    code = sys.argv[1] if len(sys.argv) > 1 else ''

    local_scope = {}
    exec(code, {}, local_scope)

    # Any seed component not supplied by the caller defaults to 0.  A
    # caller-supplied dict is completed in place, missing keys appended in
    # this order.
    seed_defaults = {"K_up": 0, "K_dn": 0, "Kp_up": 0, "Kp_dn": 0,
                     "vx": 0, "sx": 0, "vsx": 0}
    seed_val = local_scope.get('seed_val', dict(seed_defaults))
    for key, default in seed_defaults.items():
        seed_val.setdefault(key, default)

    rmt_SWM  = local_scope.get('rmt_SWM', 1)    # strength of remote hopping terms

    # Interaction type: "short" or "long"; any value other than "short" is
    # treated as "long".
    V_type   = local_scope.get('V_type', "short")
    if V_type != "short":
        V_type = "long"

    alp       = local_scope.get('alp', 0)
    er        = local_scope.get('er', 15)
    Hund      = local_scope.get('Hund', 0)
    ke        = local_scope.get('ke', 1.44e-4)    # Coulomb constant meV*cm

    V0xnu_0  = local_scope.get('V0xnu_0', 0.1)    # product V0 * nu_0
    nu_0     = SWM["gamma_1"] / (3 * np.pi * SWM["gamma_0"]**2 * 2.46e-8**2)  # [meV^-1 * cm^-2]

    beta      = local_scope.get('beta', 58)       # 1/(K_B*T) in 1/meV (ind. parameter) (T=0.2 K)

    SOC       = local_scope.get('SOC',  0)
    SOC_dir   = local_scope.get('SOC_dir', 0)

    mix       = local_scope.get('mix', 0.9)

    N_U       = local_scope.get('N_U',  0)
    U_fr      = local_scope.get('U_fr', 0)
    U_to      = local_scope.get('U_to', 60)
    U_val     = np.linspace(U_fr, U_to, N_U + 1)

    N_ne      = local_scope.get('N_ne', 0)
    ne_fr     = local_scope.get('ne_fr', 0.01e12)
    ne_to     = local_scope.get('ne_to', 1.01e12)
    ne_val    = np.linspace(ne_fr, ne_to, N_ne + 1)

    par_num   = local_scope.get('par_num', 0)     # # of physical cores (0 -> serial)

    dk        = local_scope.get('dk', 1.0e-3)     # in unit of 1/a (ind. parameter)
    kmax      = local_scope.get('kmax', 0.120)    # in unit of 1/a (ind. parameter)

    run_idx = local_scope.get('run_idx',  0)

    H_renorm = local_scope.get('H_renorm', True)
    fix_sym = local_scope.get('fix_sym',  '')
    calc_type = local_scope.get('calc_type', '')

    consts = {
        "N_lay": 3,
        "SWM": SWM,
        "dk": dk,
        "kmax": kmax,

        "V_type": V_type,
        # nu_0 is in meV^-1 cm^-2, so V0 = (V0 * nu_0) / nu_0 is in meV * cm^2
        "V0": V0xnu_0 / nu_0,
        "ke": ke,
        "er": er,
        "d_gate": 3.69e-6,
        "alp": alp,
        "Hund": Hund,
        "beta": beta,

        "SOC": SOC,
        "SOC_dir": SOC_dir,

        "max_count": 200,
        "tol": 1e-7,
        "mix": mix,
        "seed_val": seed_val,

        "ne_val": ne_val,

        "rmt_SWM": rmt_SWM,

        "run_idx": run_idx,

        "H_renorm": H_renorm,
        "fix_sym": fix_sym,
        "calc_type": calc_type
    }

    if N_U == 0 or par_num == 0:
        # Serial sweep: initialise the solver's global variables once, then
        # run every U value in this process.
        Hamiltonian_solver_sweep(SWM, kmax, dk, N_lay=consts["N_lay"], U=0,
                                 seed_val=seed_val, H_renorm=H_renorm)

        for U in U_val:
            SCF(U, consts).Mean_field_State()
    else:
        # Parallel sweep: one task per U value; run_SCF initialises the
        # solver's global variables inside each worker process.
        with Pool(processes=par_num) as pool:
            pending = [(U, pool.apply_async(run_SCF, (U, consts)))
                       for U in U_val]

            # No more tasks; wait until every submitted run has finished.
            pool.close()
            pool.join()

            # apply_async only surfaces a worker's exception through its
            # AsyncResult.  Without get(), a failed U value would disappear
            # silently, so report every failure explicitly.
            for U, result in pending:
                try:
                    result.get()
                except Exception as err:
                    print("SCF run for U = {} failed: {!r}".format(U, err),
                          file=sys.stderr)

if __name__ == '__main__':
    # Entry-point guard: main() creates a multiprocessing Pool, and worker
    # processes may re-import this module (e.g. under the "spawn" start
    # method), so the sweep must only start when run as a script.
    main()